import torch
from torch import optim
from torch import nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader

import numpy as np
from jsputils import encoding
import scipy.stats as stats
from scipy.spatial.distance import pdist

from fastprogress import progress_bar
import gc
from IPython.core.debugger import set_trace

class SparsePositiveLasso(nn.Module):
    """Linear read-out implemented as a single ``nn.Linear`` layer.

    The class only defines the affine map ``x -> x @ W.T + b``; it does not
    itself impose sparsity or non-negativity on the weights, so any such
    constraint has to be applied by the training code.

    Args:
        input_dim: size of each input sample.
        output_dim: size of each output sample.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Exposed as ``self.linear`` so callers can reach ``weight``/``bias``.
        self.linear = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        return self.linear(x)

# Hyperparameters
learning_rate = 0.0001  # initial step size handed to the Adam optimizer
max_epochs = 3000       # upper bound on the number of training epochs
batch_size = 250        # samples per mini-batch
weight_std = 0.0001     # std of the zero-mean normal weight initialisation
tol = 0.001             # summed change in validation correlations below which training counts as stalled
eval_freq = 250         # number of epochs between validation checks

# Prefer the second GPU, as before, but fall back to whatever is present so
# the module still works on hosts with a single GPU or no GPU at all.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
  
def _univar_and_rdv(responses):
    """Return the per-row mean and the condensed correlation-distance vector of *responses*."""
    return np.mean(responses, axis=1), pdist(responses, 'correlation')


def fit_(encoder, layer, domain, l1_lambda=0.00001):
    """Fit a non-negative, L1-penalised linear encoding model.

    A ``SparsePositiveLasso`` is trained with Adam on the ``'train'`` split
    using MSE plus ``l1_lambda * sum(|W|)``; after every optimizer step the
    weights are clamped to be >= 0.  Every ``eval_freq`` epochs the model is
    scored on the ``'val'`` split with two Pearson correlations: one between
    true and predicted row means, one between true and predicted
    correlation-distance vectors (``pdist(..., 'correlation')``).  When the
    summed absolute change of both scores relative to the last recorded
    scores falls below ``tol`` the learning rate is halved once; the next
    such plateau stops training.  The same two correlations are printed for
    the ``'test'`` split at the end.

    Uses the module-level settings ``learning_rate``, ``max_epochs``,
    ``batch_size``, ``weight_std``, ``tol``, ``eval_freq`` and ``device``.

    Args:
        encoder: object providing ``get_encoding_features(layer, domain)`` and
            ``get_encoding_voxels()``; both results are indexed with
            ``'train'``, ``'val'`` and ``'test'`` and must yield NumPy arrays
            (presumably samples x features / samples x voxels -- confirm
            against the encoder implementation).
        layer: forwarded unchanged to ``get_encoding_features``.
        domain: forwarded unchanged to ``get_encoding_features``.
        l1_lambda: multiplier of the L1 weight penalty added to the MSE loss.

    Returns:
        Tuple ``(weights, bias)`` of NumPy arrays taken from the trained
        linear layer; ``weights`` has shape ``(y.shape[1], X.shape[1])``.
    """
    # Query the encoder once per getter instead of once per split.
    features = encoder.get_encoding_features(layer, domain)
    voxels = encoder.get_encoding_voxels()

    X, y = features['train'], voxels['train']
    X_val, y_val = features['val'], voxels['val']
    X_test, y_test = features['test'], voxels['test']

    # Validation targets are fixed, so their summary statistics are computed once.
    val_univar, val_rdv = _univar_and_rdv(y_val)

    # Convert the training arrays to float32 tensors on the target device.
    X_tensor = torch.from_numpy(X).float().to(device)
    y_tensor = torch.from_numpy(y).float().to(device)

    print(X_tensor.shape, y_tensor.shape)

    # Dataset and loader for mini-batch gradient descent.
    dataset = torch.utils.data.TensorDataset(X_tensor, y_tensor)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=0
    )

    model = SparsePositiveLasso(X.shape[1], y.shape[1]).to(device)
    with torch.no_grad():
        # Start from small random weights (re-initialises nn.Linear's default).
        model.linear.weight.normal_(mean=0.0, std=weight_std)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    previous_univar_corr = -np.inf
    previous_rdv_corr = -np.inf
    lr_reduced = False  # set once the learning rate has been halved

    for epoch in progress_bar(range(max_epochs)):
        for inputs, targets in dataloader:
            # Gradients accumulate by default; reset them for this batch.
            optimizer.zero_grad()

            outputs = model(inputs)

            # Data term plus L1 penalty on the weights (bias is not penalised).
            loss = criterion(outputs, targets)
            loss = loss + torch.sum(torch.abs(model.linear.weight)) * l1_lambda

            loss.backward()
            optimizer.step()

            # Project back onto the non-negative orthant after each step.
            with torch.no_grad():
                model.linear.weight.clamp_(min=0)

        if epoch % eval_freq != 0:
            continue

        # ``loss`` is the penalised loss of the last batch of this epoch.
        print('Epoch:', epoch, 'Loss:', loss.item())

        weights = model.linear.weight.detach().cpu().numpy()
        bias = model.linear.bias.detach().cpu().numpy()

        preds = X_val @ weights.T + bias
        pred_univar, pred_rdv = _univar_and_rdv(preds)

        univar_corr = stats.pearsonr(val_univar, pred_univar)[0]
        rdv_corr = stats.pearsonr(val_rdv, pred_rdv)[0]

        print(univar_corr, rdv_corr)

        # NOTE: if either correlation is NaN the comparison below is False,
        # so a NaN score is treated as "still changing" and training goes on.
        total_change = (
            np.abs(previous_univar_corr - univar_corr)
            + np.abs(previous_rdv_corr - rdv_corr)
        )

        if total_change < tol:
            if lr_reduced:
                # Second plateau after the learning-rate cut: stop training.
                break
            print('reducing learning rate')
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.5  # halve the learning rate
            lr_reduced = True
        else:
            previous_univar_corr = univar_corr
            previous_rdv_corr = rdv_corr

    print('Training complete')

    # Final parameters, copied to host memory as NumPy arrays.
    weights = model.linear.weight.detach().cpu().numpy()
    bias = model.linear.bias.detach().cpu().numpy()

    test_preds = X_test @ weights.T + bias

    test_pred_univar, test_pred_rdv = _univar_and_rdv(test_preds)
    test_true_univar, test_true_rdv = _univar_and_rdv(y_test)

    print(stats.pearsonr(test_pred_univar, test_true_univar))
    print(stats.pearsonr(test_pred_rdv, test_true_rdv))

    # Drop every reference to device tensors (the loader holds the dataset,
    # which holds the training tensors) before asking the CUDA caching
    # allocator to give memory back.
    del model, inputs, targets, outputs, loss, optimizer
    del dataloader, dataset, X_tensor, y_tensor
    gc.collect()
    torch.cuda.empty_cache()

    return weights, bias


# def fit(encoder, layer, domain):
#     # Convert numpy arrays to PyTorch tensors and send to GPU
#     X = encoder.get_encoding_features(layer,domain)['train']
#     y = encoder.get_encoding_voxels()['train']
    
#     X_tensor = torch.from_numpy(X).float().to(device)
#     y_tensor = torch.from_numpy(y).float().to(device)
    
#     # Create PyTorch dataset and dataloader for mini-batch gradient descent
#     dataset = torch.utils.data.TensorDataset(X_tensor, y_tensor)
#     dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

#     model = SparsePositiveLasso(X.shape[1], y.shape[1]).to(device)
#     model.linear.weight.data.normal_(mean=0.0, std=weight_std)
#     criterion = nn.MSELoss()  # Mean Squared Loss
#     optimizer = optim.Adam(model.parameters(), lr=learning_rate)

